chore: Support multi-GPU training via accelerate #5
base: main
Conversation
Congratulations on this amazing work @bclavie 🤩, and thank you also for the documentation with the DataLoader. I'll run your branch in the following days to make sure everything runs smoothly, and then merge and release a new version.
Thank you! Please do let me know if you run into any issues -- things are training fine right now but I'm using a pretty weird setup so there might still be some issues.
To be fair there's no code there at the moment, but I'm happy to update with mock data in a bit if you think it'd be useful!
I don't have multiple GPUs (not even one) at home, so I cannot mimic your environment. I propose to add the changes suggested below. I also updated the documentation a bit to show how to create a dataset. All tests pass locally with the code from your branch and my updates, so feel free to copy-paste the code I commented. Also, what versions of transformers and accelerate are you using?
Hey, did you submit the comments? I can't see the suggested code anywhere, though it might be me being holiday-tired... Thank you for taking the time to look at this and improve it! I've run some more experiments, and for full disclosure so far:
My feeling is that it might actually be unsafe to merge as a "mature" feature at this stage, but merging it and labelling it experimental support could be useful? (As for neural-cherche itself, I really like the lightweight nature of the lib, but currently I'm running into some issues where my models end up stuck in some kind of "compressed similarity" land and hard negatives are always extremely close to positives in similarity, which doesn't happen with the main ColBERT codebase -- I'm training a ColBERT from scratch and will try to diagnose once I have more time!)
# Multi-GPU
Neural-Cherche supports multi-GPU training through [Accelerator](https://huggingface.co/docs/accelerate/package_reference/accelerator). Every neural-cherche model can be trained this way. Here is a tutorial.
```python
import torch
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader
from neural_cherche import models, train
if __name__ == "__main__":
    # Wrap the training loop in a main guard to avoid multiprocessing issues
    # when accelerate spawns one process per GPU.
    accelerator = Accelerator()

    save_each_epoch = True

    model = models.SparseEmbed(
        model_name_or_path="distilbert-base-uncased",
        accelerate=True,
        device=accelerator.device,
    ).to(accelerator.device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # Dataset creation using the HuggingFace Datasets library.
    dataset = Dataset.from_dict(
        {
            "anchors": ["anchor 1", "anchor 2", "anchor 3", "anchor 4"],
            "positives": ["positive 1", "positive 2", "positive 3", "positive 4"],
            "negatives": ["negative 1", "negative 2", "negative 3", "negative 4"],
        }
    )

    # Convert the dataset to a DataLoader.
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Wrap the model, optimizer and dataloader with the accelerator.
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(2):
        for batch in data_loader:
            # Each batch is a dict with "anchors", "positives" and "negatives" keys.
            anchors, positives, negatives = (
                batch["anchors"],
                batch["positives"],
                batch["negatives"],
            )

            loss = train.train_sparse_embed(
                model=model,
                optimizer=optimizer,
                anchor=anchors,
                positive=positives,
                negative=negatives,
                threshold_flops=30,
                accelerator=accelerator,
            )

        if accelerator.is_main_process and save_each_epoch:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                "checkpoint/epoch" + str(epoch),
            )

    # Save at the end of the training loop.
    # Only the main process should export the model.
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained("checkpoint")
```
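Note that a script like the one above is normally started with the Accelerate CLI rather than plain `python`, e.g. `accelerate launch train.py` after configuring the machine once with `accelerate config`. This is standard Accelerate usage rather than anything specific to neural-cherche, and `train.py` is just a placeholder for wherever you put the tutorial code.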
I added a clear example of how to create the dataset using the HuggingFace Datasets library.
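For reference, here is a minimal sketch (not part of the suggested change, assuming only the `datasets` and `torch` packages already used in the tutorial) showing the batch structure such a DataLoader produces:

```python
from datasets import Dataset
from torch.utils.data import DataLoader

dataset = Dataset.from_dict(
    {
        "anchors": ["anchor 1", "anchor 2"],
        "positives": ["positive 1", "positive 2"],
        "negatives": ["negative 1", "negative 2"],
    }
)

data_loader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch in data_loader:
    # With the default collate function, each batch is a dict mapping the column
    # names to lists of strings, which is what the training loop above unpacks.
    print(batch["anchors"], batch["positives"], batch["negatives"])
```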
I ran into some trouble with the extra position_ids parameter with the DistilBERT pre-trained checkpoint, but not with the all-mpnet-base-v2 checkpoint, so I think it would be better to keep the legacy code and add an accelerate attribute to the models.
import json
import os
from abc import ABC, abstractmethod
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM, AutoTokenizer
class Base(ABC, torch.nn.Module):
"""Base class from which all models inherit.
Parameters
----------
model_name_or_path
Path to the model or the model name.
device
Device to use for the model. CPU or CUDA.
extra_files_to_load
List of extra files to load.
accelerate
Use HuggingFace Accelerate.
kwargs
Additional parameters to the model.
"""
def __init__(
self,
model_name_or_path: str,
device: str = None,
extra_files_to_load: list[str] = [],
accelerate: bool = False,
**kwargs,
) -> None:
"""Initialize the model."""
super(Base, self).__init__()
if device is not None:
self.device = device
elif torch.cuda.is_available():
self.device = "cuda"
else:
self.device = "cpu"
self.accelerate = accelerate
os.environ["TRANSFORMERS_CACHE"] = "."
self.model = AutoModelForMaskedLM.from_pretrained(
model_name_or_path, cache_dir="./", **kwargs
).to(self.device)
# Download linear layer if exists
for file in extra_files_to_load:
try:
_ = hf_hub_download(model_name_or_path, filename=file, cache_dir=".")
except Exception:  # The extra file is optional and may not exist for this checkpoint.
pass
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, device=self.device, cache_dir="./", **kwargs
)
self.model.config.output_hidden_states = True
if os.path.exists(model_name_or_path):
# Local checkpoint
self.model_folder = model_name_or_path
else:
# HuggingFace checkpoint
model_folder = os.path.join(
f"models--{model_name_or_path}".replace("/", "--"), "snapshots"
)
snapshot = os.listdir(model_folder)[-1]
self.model_folder = os.path.join(model_folder, snapshot)
self.query_pad_token = self.tokenizer.mask_token
self.original_pad_token = self.tokenizer.pad_token
def _encode_accelerate(self, texts: list[str], **kwargs) -> tuple[torch.Tensor]:
"""Encode sentences with multiples gpus.
Parameters
----------
texts
List of sentences to encode.
References
----------
[Accelerate issue.](https://github.com/huggingface/accelerate/issues/97)
"""
encoded_input = self.tokenizer(texts, return_tensors="pt", **kwargs).to(
self.device
)
position_ids = (
torch.arange(0, encoded_input["input_ids"].size(1))
.expand((len(texts), -1))
.to(self.device)
)
output = self.model(**encoded_input, position_ids=position_ids)
return output.logits, output.hidden_states[-1]
def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""Encode sentences.
Parameters
----------
texts
List of sentences to encode.
"""
if self.accelerate:
return self._encode_accelerate(texts, **kwargs)
encoded_input = self.tokenizer.batch_encode_plus(
texts, return_tensors="pt", **kwargs
)
if self.device != "cpu":
encoded_input = {
key: value.to(self.device) for key, value in encoded_input.items()
}
output = self.model(**encoded_input)
return output.logits, output.hidden_states[-1]
@abstractmethod
def forward(self, *args, **kwargs):
"""Pytorch forward method."""
pass
@abstractmethod
def encode(self, *args, **kwargs):
"""Encode documents."""
pass
@abstractmethod
def scores(self, *args, **kwargs):
"""Compute scores."""
pass
@abstractmethod
def save_pretrained(self, path: str):
"""Save model the model."""
pass
def save_tokenizer_accelerate(self, path: str) -> None:
"""Save tokenizer when using accelerate."""
tokenizer_config = {
k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
}
tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
with open(tokenizer_config_file, "w", encoding="utf-8") as file:
json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)
# dump vocab
self.tokenizer.save_vocabulary(path)
# save special tokens
special_tokens_file = os.path.join(path, "special_tokens_map.json")
with open(special_tokens_file, "w", encoding="utf-8") as file:
json.dump(
self.tokenizer.special_tokens_map,
file,
ensure_ascii=False,
indent=4,
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the base class updated with the new save_tokenizer_accelerate method and an accelerate attribute.
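A minimal usage sketch showing how the new accelerate flag fits together with the Accelerator from the tutorial above (nothing here beyond what the tutorial and the code in this thread already define):

```python
from accelerate import Accelerator

from neural_cherche import models

accelerator = Accelerator()

# accelerate=True switches encoding to the explicit position_ids path and makes
# save_pretrained go through save_tokenizer_accelerate instead of the tokenizer's
# own save_pretrained.
model = models.SparseEmbed(
    model_name_or_path="distilbert-base-uncased",
    accelerate=True,
    device=accelerator.device,
).to(accelerator.device)

model = accelerator.prepare(model)

# Only the main process should write the checkpoint.
if accelerator.is_main_process:
    accelerator.unwrap_model(model).save_pretrained("checkpoint")
```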
import json
import os
import torch
from .. import utils
from .base import Base
__all__ = ["ColBERT"]
class ColBERT(Base):
"""ColBERT model.
Parameters
----------
model_name_or_path
Path to the model or the model name.
embedding_size
Size of the output embeddings of the ColBERT model.
device
Device to use for the model. CPU or CUDA.
accelerate
Use HuggingFace Accelerate.
kwargs
Additional parameters to the SentenceTransformer model.
Examples
--------
>>> from neural_cherche import models
>>> import torch
>>> _ = torch.manual_seed(42)
>>> queries = ["Berlin", "Paris", "London"]
>>> documents = [
... "Berlin is the capital of Germany",
... "Paris is the capital of France and France is in Europe",
... "London is the capital of England",
... ]
>>> encoder = models.ColBERT(
... model_name_or_path="sentence-transformers/all-mpnet-base-v2",
... embedding_size=128,
... max_length_query=32,
... max_length_document=350,
... )
>>> scores = encoder.scores(
... queries=queries,
... documents=documents,
... )
>>> scores
tensor([22.9325, 19.8296, 20.8019])
>>> _ = encoder.save_pretrained("checkpoint")
>>> encoder = models.ColBERT(
... model_name_or_path="checkpoint",
... embedding_size=64,
... device="cpu",
... )
>>> scores = encoder.scores(
... queries=queries,
... documents=documents,
... )
>>> scores
tensor([22.9325, 19.8296, 20.8019])
>>> embeddings = encoder(
... texts=queries,
... query_mode=True
... )
>>> embeddings["embeddings"].shape
torch.Size([3, 32, 128])
>>> embeddings = encoder(
... texts=queries,
... query_mode=False
... )
>>> embeddings["embeddings"].shape
torch.Size([3, 350, 128])
"""
def __init__(
self,
model_name_or_path: str,
embedding_size: int = 128,
device: str = None,
max_length_query: int = 32,
max_length_document: int = 350,
accelerate: bool = False,
**kwargs,
) -> None:
"""Initialize the model."""
super(ColBERT, self).__init__(
model_name_or_path=model_name_or_path,
device=device,
extra_files_to_load=["linear.pt", "metadata.json"],
accelerate=accelerate,
**kwargs,
)
self.max_length_query = max_length_query
self.max_length_document = max_length_document
self.embedding_size = embedding_size
if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
linear = torch.load(
os.path.join(self.model_folder, "linear.pt"), map_location=self.device
)
self.embedding_size = linear["weight"].shape[0]
in_features = linear["weight"].shape[1]
else:
with torch.no_grad():
_, embeddings = self._encode(texts=["test"])
in_features = embeddings.shape[2]
self.linear = torch.nn.Linear(
in_features=in_features,
out_features=self.embedding_size,
bias=False,
device=self.device,
)
if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
with open(os.path.join(self.model_folder, "metadata.json"), "r") as f:
metadata = json.load(f)
self.max_length_document = metadata["max_length_document"]
self.max_length_query = metadata["max_length_query"]
if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
self.linear.load_state_dict(linear)
def encode(
self,
texts: list[str],
truncation: bool = True,
add_special_tokens: bool = False,
query_mode: bool = True,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Encode documents
Parameters
----------
texts
List of sentences to encode.
truncation
Truncate the inputs.
add_special_tokens
Add special tokens.
max_length
Maximum length of the inputs.
"""
with torch.no_grad():
embeddings = self(
texts=texts,
truncation=truncation,
add_special_tokens=add_special_tokens,
query_mode=query_mode,
**kwargs,
)
return embeddings
def forward(
self,
texts: list[str],
query_mode: bool = True,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Pytorch forward method.
Parameters
----------
texts
List of sentences to encode.
query_mode
Whether to encode queries or documents.
"""
suffix = "[Q] " if query_mode else "[D] "
texts = [suffix + text for text in texts]
self.tokenizer.pad_token = (
self.query_pad_token if query_mode else self.original_pad_token
)
kwargs = {
"truncation": True,
"padding": "max_length",
"max_length": self.max_length_query
if query_mode
else self.max_length_document,
"add_special_tokens": True,
**kwargs,
}
_, embeddings = self._encode(texts=texts, **kwargs)
return {
"embeddings": torch.nn.functional.normalize(
self.linear(embeddings), p=2, dim=2
)
}
def scores(
self,
queries: list[str],
documents: list[str],
batch_size: int = 2,
tqdm_bar: bool = True,
**kwargs,
) -> torch.Tensor:
"""Score queries and documents.
Parameters
----------
queries
List of queries.
documents
List of documents.
batch_size
Batch size.
truncation
Truncate the inputs.
add_special_tokens
Add special tokens.
tqdm_bar
Show tqdm bar.
"""
list_scores = []
for batch_queries, batch_documents in zip(
utils.batchify(
X=queries,
batch_size=batch_size,
desc="Computing scores.",
tqdm_bar=tqdm_bar,
),
utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
):
queries_embeddings = self.encode(
texts=batch_queries,
query_mode=True,
**kwargs,
)
documents_embeddings = self.encode(
texts=batch_documents,
query_mode=False,
**kwargs,
)
late_interactions = torch.einsum(
"bsh,bth->bst",
queries_embeddings["embeddings"],
documents_embeddings["embeddings"],
)
late_interactions = torch.max(late_interactions, axis=2).values.sum(axis=1)
list_scores.append(late_interactions)
return torch.cat(list_scores, dim=0)
def save_pretrained(self, path: str) -> "ColBERT":
"""Save model the model.
Parameters
----------
path
Path to save the model.
"""
self.model.save_pretrained(path)
torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
self.tokenizer.pad_token = self.original_pad_token
with open(os.path.join(path, "metadata.json"), "w") as f:
json.dump(
{
"max_length_query": self.max_length_query,
"max_length_document": self.max_length_document,
},
f,
)
if self.accelerate:
self.save_tokenizer_accelerate(path=path)
else:
self.tokenizer.save_pretrained(path)
return self
ColBERT with the call to the save_tokenizer_accelerate method from the parent class :)
import json
import os
import torch
from .. import utils
__all__ = ["SparseEmbed"]
from .splade import Splade
class SparseEmbed(Splade):
"""SparseEmbed model.
Parameters
----------
model_name_or_path
Path to the model or the model name. It should be a SentenceTransformer model.
embedding_size
Size of the output embeddings of the SparseEmbed model.
kwargs
Additional parameters to the pre-trained model.
Examples
--------
>>> from neural_cherche import models
>>> import torch
>>> _ = torch.manual_seed(42)
>>> device = "mps"
>>> model = models.SparseEmbed(
... model_name_or_path="distilbert-base-uncased",
... device=device,
... )
>>> queries_embeddings = model.encode(
... ["Sports", "Music"],
... )
>>> queries_embeddings["activations"].shape
torch.Size([2, 128])
>>> queries_embeddings["sparse_activations"].shape
torch.Size([2, 30522])
>>> queries_embeddings["embeddings"].shape
torch.Size([2, 128, 128])
>>> documents_embeddings = model.encode(
... ["Music is great.", "Sports is great."],
... query_mode=False,
... )
>>> documents_embeddings["activations"].shape
torch.Size([2, 256])
>>> documents_embeddings["sparse_activations"].shape
torch.Size([2, 30522])
>>> documents_embeddings["embeddings"].shape
torch.Size([2, 256, 128])
>>> model.scores(
... queries=["Sports", "Music"],
... documents=["Sports is great.", "Music is great."],
... batch_size=1,
... )
tensor([64.2330, 54.0180], device='mps:0')
>>> _ = model.save_pretrained("checkpoint")
>>> model = models.SparseEmbed(
... model_name_or_path="checkpoint",
... device="cpu",
... )
>>> model.scores(
... queries=["Sports", "Music"],
... documents=["Sports is great.", "Music is great."],
... batch_size=2,
... )
tensor([64.2330, 54.0180])
References
----------
1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065)
"""
def __init__(
self,
model_name_or_path: str = None,
embedding_size: int = 128,
max_length_query: int = 128,
max_length_document: int = 256,
device: str = None,
accelerate: bool = False,
**kwargs,
) -> None:
super(SparseEmbed, self).__init__(
model_name_or_path=model_name_or_path,
device=device,
extra_files_to_load=["linear.pt", "metadata.json"],
accelerate=accelerate,
**kwargs,
)
self.embedding_size = embedding_size
self.softmax = torch.nn.Softmax(dim=2).to(self.device)
if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
linear = torch.load(
os.path.join(self.model_folder, "linear.pt"), map_location=self.device
)
self.embedding_size = linear["weight"].shape[0]
in_features = linear["weight"].shape[1]
else:
with torch.no_grad():
_, embeddings = self._encode(texts=["test"])
in_features = embeddings.shape[2]
self.linear = torch.nn.Linear(
in_features=in_features,
out_features=self.embedding_size,
bias=False,
device=self.device,
)
if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
self.linear.load_state_dict(linear)
if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
metadata = json.load(file)
max_length_query = metadata["max_length_query"]
max_length_document = metadata["max_length_document"]
self.max_length_query = max_length_query
self.max_length_document = max_length_document
def forward(
self,
texts: list[str],
query_mode: bool = True,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Pytorch forward method.
Parameters
----------
texts
List of documents to encode.
query_mode
Whether to encode queries or documents.
"""
suffix = "[Q] " if query_mode else "[D] "
texts = [suffix + text for text in texts]
self.tokenizer.pad_token = (
self.query_pad_token if query_mode else self.original_pad_token
)
k_tokens = self.max_length_query if query_mode else self.max_length_document
logits, embeddings = self._encode(
texts=texts,
truncation=True,
padding="max_length",
max_length=k_tokens,
add_special_tokens=True,
**kwargs,
)
activations = self._update_activations(
**self._get_activation(logits=logits),
k_tokens=k_tokens,
)
attention = self._get_attention(
logits=logits,
activations=activations["activations"],
)
embeddings = torch.bmm(
attention,
embeddings,
)
return {
"embeddings": self.relu(self.linear(embeddings)),
"sparse_activations": activations["sparse_activations"],
"activations": activations["activations"],
}
def _get_attention(
self, logits: torch.Tensor, activations: torch.Tensor
) -> torch.Tensor:
"""Extract attention scores from MLM logits based on activated tokens."""
attention = logits.gather(
dim=2,
index=torch.stack(
[
torch.stack([token for _ in range(logits.shape[1])])
for token in activations
]
),
)
return self.softmax(attention)
def save_pretrained(
self,
path: str,
):
"""Save model the model."""
self.model.save_pretrained(path)
self.tokenizer.pad_token = self.original_pad_token
if self.accelerate:
self.save_tokenizer_accelerate(path)
else:
self.tokenizer.save_pretrained(path)
torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
with open(os.path.join(path, "metadata.json"), "w") as file:
json.dump(
fp=file,
obj={
"max_length_query": self.max_length_query,
"max_length_document": self.max_length_document,
},
indent=4,
)
return self
def scores(
self,
queries: list[str],
documents: list[str],
batch_size: int = 32,
tqdm_bar: bool = True,
**kwargs,
) -> torch.Tensor:
"""Compute similarity scores between queries and documents."""
dense_scores = []
for batch_queries, batch_documents in zip(
utils.batchify(
X=queries,
batch_size=batch_size,
desc="Computing scores.",
tqdm_bar=tqdm_bar,
),
utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
):
queries_embeddings = self.encode(
texts=batch_queries,
query_mode=True,
**kwargs,
)
documents_embeddings = self.encode(
texts=batch_documents,
query_mode=False,
**kwargs,
)
dense_scores.append(
utils.pairs_dense_scores(
queries_activations=queries_embeddings["activations"],
documents_activations=documents_embeddings["activations"],
queries_embeddings=queries_embeddings["embeddings"],
documents_embeddings=documents_embeddings["embeddings"],
)
)
return torch.cat(dense_scores, dim=0)
SparseEmbed with the call to the save_tokenizer_accelerate method from the parent class :)
import json
import os
import string
import torch
from .. import utils
from .base import Base
__all__ = ["Splade"]
class Splade(Base):
"""SpladeV1 model.
Parameters
----------
model_name_or_path
Path to the model or the model name.
device
Device to use for the model. CPU or CUDA.
max_length_query
Maximum length of the queries.
max_length_document
Maximum length of the documents.
extra_files_to_load
List of extra files to load.
accelerate
Use HuggingFace Accelerate.
kwargs
Additional parameters to the pre-trained model.
Examples
--------
>>> from neural_cherche import models
>>> import torch
>>> _ = torch.manual_seed(42)
>>> model = models.Splade(
... model_name_or_path="distilbert-base-uncased",
... device="mps",
... )
>>> queries_activations = model.encode(
... ["Sports", "Music"],
... )
>>> documents_activations = model.encode(
... ["Music is great.", "Sports is great."],
... query_mode=False,
... )
>>> queries_activations["sparse_activations"].shape
torch.Size([2, 30522])
>>> model.scores(
... queries=["Sports", "Music"],
... documents=["Sports is great.", "Music is great."],
... batch_size=1
... )
tensor([318.1384, 271.8006], device='mps:0')
>>> _ = model.save_pretrained("checkpoint")
>>> model = models.Splade(
... model_name_or_path="checkpoint",
... device="mps",
... )
>>> model.scores(
... queries=["Sports", "Music"],
... documents=["Sports is great.", "Music is great."],
... batch_size=1
... )
tensor([318.1384, 271.8006], device='mps:0')
References
----------
1. [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)
"""
def __init__(
self,
model_name_or_path: str = None,
device: str = None,
max_length_query: int = 128,
max_length_document: int = 256,
extra_files_to_load: list[str] = ["metadata.json"],
accelerate: bool = False,
**kwargs,
) -> None:
super(Splade, self).__init__(
model_name_or_path=model_name_or_path,
device=device,
extra_files_to_load=extra_files_to_load,
accelerate=accelerate,
**kwargs,
)
self.relu = torch.nn.ReLU().to(self.device)
if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
metadata = json.load(file)
max_length_query = metadata["max_length_query"]
max_length_document = metadata["max_length_document"]
self.max_length_query = max_length_query
self.max_length_document = max_length_document
def encode(
self,
texts: list[str],
query_mode: bool = True,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Encode documents
Parameters
----------
texts
List of documents to encode.
truncation
Whether to truncate the documents.
padding
Whether to pad the documents.
max_length
Maximum length of the documents.
"""
with torch.no_grad():
return self(
texts=texts,
query_mode=query_mode,
**kwargs,
)
def decode(
self,
sparse_activations: torch.Tensor,
clean_up_tokenization_spaces: bool = False,
skip_special_tokens: bool = True,
k_tokens: int = 96,
) -> list[str]:
"""Decode activated tokens ids where activated value > 0.
Parameters
----------
sparse_activations
Activated tokens.
clean_up_tokenization_spaces
Whether to clean up the tokenization spaces.
skip_special_tokens
Whether to skip special tokens.
k_tokens
Number of tokens to keep.
"""
activations = self._filter_activations(
sparse_activations=sparse_activations, k_tokens=k_tokens
)
# Decode
return [
" ".join(
activation.translate(str.maketrans("", "", string.punctuation)).split()
)
for activation in self.tokenizer.batch_decode(
activations,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
skip_special_tokens=skip_special_tokens,
)
]
def forward(
self,
texts: list[str],
query_mode: bool,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Pytorch forward method.
Parameters
----------
texts
List of documents to encode.
query_mode
Whether to encode queries or documents.
"""
suffix = "[Q] " if query_mode else "[D] "
texts = [suffix + text for text in texts]
self.tokenizer.pad_token = (
self.query_pad_token if query_mode else self.original_pad_token
)
k_tokens = self.max_length_query if query_mode else self.max_length_document
logits, _ = self._encode(
texts=texts,
truncation=True,
padding="max_length",
max_length=k_tokens,
add_special_tokens=True,
**kwargs,
)
activations = self._get_activation(logits=logits)
activations = self._update_activations(
**activations,
k_tokens=k_tokens,
)
return {"sparse_activations": activations["sparse_activations"]}
def save_pretrained(
self,
path: str,
):
"""Save model the model.
Parameters
----------
path
Path to save the model.
"""
self.model.save_pretrained(path)
self.tokenizer.pad_token = self.original_pad_token
if self.accelerate:
self.save_tokenizer_accelerate(path)
else:
self.tokenizer.save_pretrained(path)
with open(os.path.join(path, "metadata.json"), "w") as file:
json.dump(
fp=file,
obj={
"max_length_query": self.max_length_query,
"max_length_document": self.max_length_document,
},
indent=4,
)
return self
def scores(
self,
queries: list[str],
documents: list[str],
batch_size: int = 32,
tqdm_bar: bool = True,
**kwargs,
) -> torch.Tensor:
"""Compute similarity scores between queries and documents.
Parameters
----------
queries
List of queries.
documents
List of documents.
batch_size
Batch size.
tqdm_bar
Show a progress bar.
"""
sparse_scores = []
for batch_queries, batch_documents in zip(
utils.batchify(
X=queries,
batch_size=batch_size,
desc="Computing scores.",
tqdm_bar=tqdm_bar,
),
utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
):
queries_embeddings = self.encode(
batch_queries,
query_mode=True,
**kwargs,
)
documents_embeddings = self.encode(
batch_documents,
query_mode=False,
**kwargs,
)
sparse_scores.append(
torch.sum(
queries_embeddings["sparse_activations"]
* documents_embeddings["sparse_activations"],
axis=1,
)
)
return torch.cat(sparse_scores, dim=0)
def _get_activation(self, logits: torch.Tensor) -> dict[str, torch.Tensor]:
"""Returns activated tokens."""
return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}
def _filter_activations(
self, sparse_activations: torch.Tensor, k_tokens: int
) -> list[torch.Tensor]:
"""Among the set of activations, select the ones with a score > 0."""
scores, activations = torch.topk(input=sparse_activations, k=k_tokens, dim=-1)
return [
torch.index_select(
activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
)
for score, activation in zip(scores, activations)
]
def _update_activations(
self, sparse_activations: torch.Tensor, k_tokens: int
) -> torch.Tensor:
"""Returns activated tokens."""
activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
zero_tensor = torch.zeros_like(sparse_activations, dtype=int)
updated_sparse_activations = sparse_activations * zero_tensor.scatter(
dim=1, index=activations.long(), value=1
)
return {
"activations": activations,
"sparse_activations": updated_sparse_activations,
}
Splade with the call to the save_tokenizer_accelerate method from the parent class :)
Ahah missed this, sorry.
It could come from the loss function, which is quite simple? Would love to get your feedback on this if you find anything. Overall, I think it's fine to push your work to master if we use the flag.
No worries, I've applied the changes 1:1, except for the tutorial page (I added a note that support is partial/in-progress, so people don't get the impression it's fully supported yet!)
I think that's probably it... I'll definitely try to figure out exactly which component has the biggest impact once I've got some more time.
Hey! Great work on the library. I've been playing with it and ran into a few issues with in-place operations when trying to train on multiple GPUs:
Setting the device this way also really doesn't play nice with the default tokeniser export, so there's a workaround to export the files individually rather than relying on risky JSON decoding.
I've also added a doc page to show how simple it is to parallelise training with just those few changes and some very slight code modifications in a training script.
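As a small sanity check (a sketch only, assuming the `checkpoint` directory produced by the tutorial above), a checkpoint written through the manual tokenizer export should still load with the standard transformers API:

```python
from transformers import AutoTokenizer

# save_tokenizer_accelerate writes tokenizer_config.json, the vocabulary files and
# special_tokens_map.json by hand, so the directory can be loaded as usual.
tokenizer = AutoTokenizer.from_pretrained("checkpoint")
print(tokenizer.mask_token, tokenizer.pad_token)
```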